{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "026ebf6e",
   "metadata": {},
   "source": [
    "# Using the BRepNet input features as a face and edge similarity metric\n",
    "This example notebook shows how the BRepNet input features can be used as a similarity metric for finding similar faces."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6e261994",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import json\n",
    "import numpy as np\n",
    "import os\n",
    "if os.path.isfile('../models/brepnet.py'):\n",
    "    # We are in the notebooks directory.  Change to the root\n",
    "    os.chdir('../')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d0909d10",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# This viewer allows you to visualize the results\n",
    "from visualization.jupyter_segmentation_viewer import JupyterSegmentationViewer\n",
    "\n",
    "# Useful functions for loading data\n",
    "import utils.data_utils as data_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8be1c9b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder holding the example STEP files; the cells below search it for *.stp files\n",
    "step_folder = Path(\"./example_files/step_examples\")\n",
    "\n",
    "\n",
    "# We will also need to know the feature standardization for the dataset used to train the model\n",
    "# This is found in the dataset file created by pipeline/build_dataset_file.py or pipeline/quickstart.py\n",
    "feature_standardization_path = Path(\"./example_files/feature_standardization/s2.0.0_step_all_features.json\")\n",
    "standardization_data = data_utils.load_json_data(feature_standardization_path)\n",
    "# Statistics used below, looked up by \"face_features\", \"edge_features\" and \"coedge_features\"\n",
    "feature_standardization = standardization_data[\"feature_standardization\"]\n",
    "\n",
    "def standardize_features(feature_tensor, stats):\n",
    "    \"\"\"\n",
    "    Standardize every feature column using per-feature dataset statistics.\n",
    "\n",
    "    This is similar to the standardization done in the dataloader.  The\n",
    "    result is returned as computed by NumPy; no dtype conversion is applied.\n",
    "\n",
    "    Args:\n",
    "        feature_tensor: Array whose second axis indexes the features\n",
    "            (expected layout: one row per entity).\n",
    "        stats: One dict per feature, each providing \"mean\" and\n",
    "            \"standard_deviation\".\n",
    "\n",
    "    Returns:\n",
    "        (feature_tensor - mean) / standard_deviation, applied per feature.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If the number of feature columns differs from\n",
    "            len(stats), or a standard deviation is not greater than 1e-7.\n",
    "    \"\"\"\n",
    "    assert feature_tensor.shape[1] == len(stats)\n",
    "    min_sd = 1e-7\n",
    "    column_means = []\n",
    "    column_sds = []\n",
    "    for stat in stats:\n",
    "        sd = stat[\"standard_deviation\"]\n",
    "        assert sd > min_sd, \"Feature has zero standard deviation\"\n",
    "        column_means.append(stat[\"mean\"])\n",
    "        column_sds.append(sd)\n",
    "\n",
    "    # Add a leading axis so the statistics broadcast over all the entities\n",
    "    mean_row = np.asarray(column_means, dtype=float)[np.newaxis, :]\n",
    "    sd_row = np.asarray(column_sds, dtype=float)[np.newaxis, :]\n",
    "    return (feature_tensor - mean_row) / sd_row\n",
    "    \n",
    "def standarize_data(data, feature_standardization):\n",
    "    \"\"\"\n",
    "    Standardize the face, edge and coedge feature arrays of one solid.\n",
    "\n",
    "    The three entries of `data` are replaced in place, each using the\n",
    "    statistics stored under the same key in `feature_standardization`,\n",
    "    and the same `data` object is returned.\n",
    "    \"\"\"\n",
    "    for key in (\"face_features\", \"edge_features\", \"coedge_features\"):\n",
    "        data[key] = standardize_features(data[key], feature_standardization[key])\n",
    "    return data\n",
    "    \n",
    "def find_faces_to_edges(coedge_to_face, coedge_to_edge):\n",
    "    \"\"\"\n",
    "    Collect, for every face, the set of edges reached through its coedges.\n",
    "\n",
    "    Args:\n",
    "        coedge_to_face: Array giving the face index of each coedge.\n",
    "        coedge_to_edge: Array giving the edge index of each coedge.\n",
    "\n",
    "    Returns:\n",
    "        A list in which entry i is the set of edge indices of face i.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If the face indices found are not exactly\n",
    "            0, 1, ..., (number of distinct faces - 1).\n",
    "    \"\"\"\n",
    "    edges_by_face = {}\n",
    "    for coedge_index in range(coedge_to_face.shape[0]):\n",
    "        face = coedge_to_face[coedge_index]\n",
    "        edges_by_face.setdefault(face, set()).add(coedge_to_edge[coedge_index])\n",
    "\n",
    "    # The dict is turned into a list, so the face indices must be contiguous\n",
    "    num_faces = len(edges_by_face)\n",
    "    assert all(face_index in edges_by_face for face_index in range(num_faces))\n",
    "    return [edges_by_face[face_index] for face_index in range(num_faces)]\n",
    "\n",
    "def pool_edge_data_onto_faces(data):\n",
    "    \"\"\"\n",
    "    Extend the features of every face with max-pooled edge features.\n",
    "\n",
    "    For each face, the features of all edges adjacent to it are combined\n",
    "    with an element-wise maximum and appended to the face's own features.\n",
    "\n",
    "    Args:\n",
    "        data: Mapping with the arrays \"face_features\", \"edge_features\",\n",
    "            \"coedge_to_face\" and \"coedge_to_edge\".\n",
    "\n",
    "    Returns:\n",
    "        Array with one row per face: the face features followed by the\n",
    "        pooled edge features.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If a coedge refers to an edge index outside the\n",
    "            edge features, or if the number of faces found through the\n",
    "            coedges differs from the number of face feature rows.\n",
    "    \"\"\"\n",
    "    face_features = data[\"face_features\"]\n",
    "    edge_features = data[\"edge_features\"]\n",
    "    coedge_to_face = data[\"coedge_to_face\"]\n",
    "    coedge_to_edge = data[\"coedge_to_edge\"]\n",
    "\n",
    "    num_edges = edge_features.shape[0]\n",
    "    for edge in coedge_to_edge:\n",
    "        assert edge < num_edges\n",
    "\n",
    "    # One pooled row per face, in face index order\n",
    "    pooled_rows = [\n",
    "        np.max(np.stack([edge_features[edge] for edge in edges_of_face]), axis=0)\n",
    "        for edges_of_face in find_faces_to_edges(coedge_to_face, coedge_to_edge)\n",
    "    ]\n",
    "    assert len(pooled_rows) == face_features.shape[0]\n",
    "    return np.concatenate([face_features, np.stack(pooled_rows)], axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4e037a1",
   "metadata": {},
   "source": [
    "Choose the model you would like to use as the query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e40e2aff",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "We found 25 example files\n"
     ]
    }
   ],
   "source": [
    "# Path.glob() yields files in a file-system dependent order, so the stems are\n",
    "# sorted to make query_index refer to the same solid on every machine\n",
    "step_file_stems = sorted(f.stem for f in step_folder.glob(\"*.stp\"))\n",
    "print(f\"We found {len(step_file_stems)} example files\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "68fd06c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index into step_file_stems of the solid to use as the query\n",
    "query_index = 17"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85208f55",
   "metadata": {},
   "source": [
    "Now you will need to select some faces on the solid to continue.  \n",
    "Double click on each face to select it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "88f6c24a",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Viewing example 21242_6c2af7c2_7\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "46ed98da437a4ab59fdf88273da9ef8c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(VBox(children=(HBox(children=(Checkbox(value=True, description='Axes', layout=Layout(height='au…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Look up the query solid chosen by query_index and open it in the viewer.\n",
    "# The faces selected there are read back later from viewer.selection_list\n",
    "file_stem = step_file_stems[query_index]\n",
    "print(f\"Viewing example {file_stem}\")\n",
    "viewer = JupyterSegmentationViewer(file_stem, step_folder, seg_folder=step_folder)\n",
    "viewer.view_solid()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2d6321f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The cells below index the feature array with viewer.selection_list,\n",
    "# so stop here if no face has been selected yet\n",
    "assert len(viewer.selection_list) > 0, \"Please select some faces on the solid\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "72336c63",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Load the raw feature data of the query solid and standardize it.\n",
    "# NOTE(review): the .npz file is read from the temp_working subfolder of the STEP\n",
    "# folder; presumably the viewer created above writes it there - confirm\n",
    "npz_folder = step_folder / \"temp_working\"\n",
    "npz_pathname = npz_folder / (file_stem + \".npz\")\n",
    "data = data_utils.load_npz_data(npz_pathname)\n",
    "data = standarize_data(data, feature_standardization)\n",
    "face_features = data[\"face_features\"]\n",
    "edge_features = data[\"edge_features\"]\n",
    "# One row per face: the face features followed by the max-pooled features of its edges\n",
    "pooled_face_edge_features = pool_edge_data_onto_faces(data)\n",
    "\n",
    "# There must be one feature row for every face known to the viewer\n",
    "assert pooled_face_edge_features.shape[0] == len(viewer.entity_mapper.face_map), \"Embedding size doesn't match solid\"\n",
    "# Keep only the rows of the faces selected in the viewer; these form the query\n",
    "selected_faces_features = pooled_face_edge_features[viewer.selection_list]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b5d63d32",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the query faces against every solid which has extracted features.\n",
    "# Path.glob() has no guaranteed order, so the files are sorted to make the\n",
    "# solid indices used in the following cells reproducible\n",
    "all_files = sorted(npz_folder.glob(\"*.npz\"))\n",
    "\n",
    "min_dists_for_each_face = []\n",
    "sum_min_dists_for_each_solid = []\n",
    "for npz_file in all_files:\n",
    "\n",
    "    # Load the input features of the candidate solid.  Separate names are used\n",
    "    # here so that the query's `data` and `pooled_face_edge_features` from the\n",
    "    # previous cell are not overwritten\n",
    "    candidate_data = data_utils.load_npz_data(npz_file)\n",
    "    candidate_data = standarize_data(candidate_data, feature_standardization)\n",
    "    candidate_features = pool_edge_data_onto_faces(candidate_data)\n",
    "\n",
    "    # Here we compute the distance from each query face to each face in\n",
    "    # the solid.  Row q holds the distances for query face q\n",
    "    dists_query_to_faces = np.stack([\n",
    "        np.linalg.norm(candidate_features - selected_face_features, axis=1)\n",
    "        for selected_face_features in selected_faces_features\n",
    "    ])\n",
    "\n",
    "    # The metric for one query face is the distance to the most similar face\n",
    "    # in the solid.  To say which solid is the best match overall we sum\n",
    "    # these minimum distances over all the query faces\n",
    "    min_dists = np.min(dists_query_to_faces, axis=1)\n",
    "    sum_min_dists_for_each_solid.append(np.sum(min_dists))\n",
    "\n",
    "    # For display we also keep track, for each face of the solid, of the\n",
    "    # distance to the closest query face\n",
    "    min_dists_for_each_face.append(np.min(dists_query_to_faces, axis=0))\n",
    "\n",
    "sum_min_dists_for_each_solid = np.array(sum_min_dists_for_each_solid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "664c4f00",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Interval [0.0, 21.304689088927404]\n"
     ]
    }
   ],
   "source": [
    "# Now we will search for the top k best matches (smallest summed distance first).\n",
    "# A stable argsort followed by a slice also works when fewer than k solids are\n",
    "# available, where np.argpartition(..., kth=range(k)) would raise a ValueError\n",
    "k = 5\n",
    "assert len(sum_min_dists_for_each_solid) > 0, \"No solids were found to compare against\"\n",
    "indices_of_smallest = np.argsort(sum_min_dists_for_each_solid, kind=\"stable\")[:k]\n",
    "\n",
    "# We want to show how close each face in each solid is\n",
    "# To do this we want some kind of range, which is computed here over\n",
    "# the faces of all the top k solids together\n",
    "all_dists_top_k = np.concatenate(\n",
    "    [min_dists_for_each_face[index] for index in indices_of_smallest]\n",
    ")\n",
    "interval = [all_dists_top_k.min(), all_dists_top_k.max()]\n",
    "print(f\"Interval {interval}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5390a62a",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "Close file 21242_6c2af7c2_7\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e6116f9fdf0647c88f044a18bda8c6f0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(VBox(children=(HBox(children=(Checkbox(value=True, description='Axes', layout=Layout(height='au…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "Close file 21492_8bd34fc1_15\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ab77ba55c8a54e38b94bd35e87e60d47",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(VBox(children=(HBox(children=(Checkbox(value=True, description='Axes', layout=Layout(height='au…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9\n",
      "Close file 56436_2a8fc254_3\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1539d1dd803c4b5cb8303b648f70176e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(VBox(children=(HBox(children=(Checkbox(value=True, description='Axes', layout=Layout(height='au…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10\n",
      "Close file 37117_89aac9d4_9\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9d0c888ba4de463c9a279d00773dcd9a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(VBox(children=(HBox(children=(Checkbox(value=True, description='Axes', layout=Layout(height='au…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20\n",
      "Close file 44647_d83249a9_0\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "699f75403047422e96ea520910127174",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(VBox(children=(HBox(children=(Checkbox(value=True, description='Axes', layout=Layout(height='au…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Now as a sanity check we always expect to find the query as the best \n",
    "# match.  The other matching faces (red) should have a similar shape\n",
    "for index in indices_of_smallest:\n",
    "    # `index` is the position of the matching solid in all_files\n",
    "    print(index)\n",
    "    close_file_stem = all_files[index].stem\n",
    "    print(f\"Close file {close_file_stem}\")\n",
    "    close_viewer = JupyterSegmentationViewer(close_file_stem, step_folder)\n",
    "    # Show the per-face distance to the closest query face.  The same\n",
    "    # interval is passed for every solid\n",
    "    dists_to_view = min_dists_for_each_face[index]\n",
    "    close_viewer.display_faces_with_heatmap(dists_to_view, interval)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
